from __future__ import annotations

from typing import Tuple, List

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch import Tensor
import cv2
from kornia.filters import gaussian_blur2d
from kornia.geometry.transform import resize
from kornia.morphology import erosion

from .util import pad_tensor_to_modulo


def isinstance_string(obj, class_name):
    """Tell whether the exact class of ``obj`` is called ``class_name``.

    Only the name of ``type(obj)`` is compared, so base classes are not
    taken into account and no import of the class itself is required.
    """
    actual_name = type(obj).__name__
    return actual_name == class_name


def move_to_device(obj, device):
    """Recursively move modules/tensors (possibly nested in containers) to ``device``.

    Modules and tensors are moved with ``.to(device)``. Dicts are rebuilt with
    the same keys and moved values. Lists *and tuples* are both returned as
    lists of moved elements. Any other type raises ``ValueError``.
    """
    if isinstance(obj, torch.nn.Module) or torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [move_to_device(item, device) for item in obj]
    raise ValueError(f"Unexpected type {type(obj)}")


def _pyrdown(im: torch.Tensor, downsize: tuple = None):
    """Blur a 3-channel image batch and shrink it.

    The input is smoothed with a 5x5 gaussian (sigma 1.0) and then resized
    bilinearly to ``downsize``; when ``downsize`` is None the target is half
    the input height and width (integer division).
    """
    target_size = (im.shape[2] // 2, im.shape[3] // 2) if downsize is None else downsize
    assert im.shape[1] == 3, "Expected shape for the input to be (n,3,height,width)"
    blurred = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
    return F.interpolate(blurred, size=target_size, mode="bilinear", align_corners=False)


def _pyrdown_mask(
    mask: torch.Tensor,
    downsize: tuple = None,
    eps: float = 1e-8,
    blur_mask: bool = True,
    round_up: bool = True,
):
    """Downscale a single-channel mask and binarize the result to {0, 1}.

    Parameters
    ----------
    mask : torch.Tensor
        mask of size (B, 1, H, W)
    downsize : tuple, optional
        (height, width) to downscale to. If None, the mask is downscaled to
        half its height and width (integer division), by default None
    eps : float, optional
        tolerance used to derive the binarization threshold, by default 1e-8
    blur_mask : bool, optional
        if truthy, apply a 5x5 gaussian filter before downscaling, by default True
    round_up : bool, optional
        if True, values >= eps become 1 and everything else 0 (any partially
        covered pixel counts as masked); if False, only values >= 1 - eps
        become 1 and everything else 0 (only fully covered pixels stay
        masked), by default True

    Returns
    -------
    torch.Tensor
        downscaled, binarized mask of size (B, 1, *downsize)
    """
    if downsize is None:
        downsize = (mask.shape[2] // 2, mask.shape[3] // 2)
    assert mask.shape[1] == 1, "Expected shape for the input to be (n,1,height,width)"

    if blur_mask:
        mask = gaussian_blur2d(mask, kernel_size=(5, 5), sigma=(1.0, 1.0))
    mask = F.interpolate(mask, size=downsize, mode="bilinear", align_corners=False)

    # Same two-step in-place binarization for both rounding modes; only the
    # cut-off differs.
    threshold = eps if round_up else 1.0 - eps
    mask[mask >= threshold] = 1
    mask[mask < threshold] = 0
    return mask


def _erode_mask(mask: torch.Tensor, ekernel: torch.Tensor = None, eps: float = 1e-8):
    """Erode ``mask`` with ``ekernel`` and snap every non-full pixel to 0.

    Without a kernel the mask is handed back untouched. Otherwise the eroded
    result is binarized in place: values >= 1 - eps become 1, the rest 0.
    """
    if ekernel is None:
        return mask

    eroded = erosion(mask, ekernel)
    cutoff = 1.0 - eps
    eroded[eroded >= cutoff] = 1
    eroded[eroded < cutoff] = 0
    return eroded


def _masked_l1_mean(a: torch.Tensor, b: torch.Tensor, selector: torch.Tensor):
    """Mean absolute difference of ``a`` and ``b`` over ``selector``; 0 if nothing is selected."""
    diff = torch.abs(a[selector] - b[selector])
    if diff.numel() == 0:
        # mean() of an empty tensor is NaN, which would poison the optimizer
        # step. sum() of an empty tensor is a zero that keeps dtype, device
        # and the autograd graph.
        return diff.sum()
    return torch.mean(diff)


def _l1_loss(
    pred: torch.Tensor,
    pred_downscaled: torch.Tensor,
    ref: torch.Tensor,
    mask: torch.Tensor,
    mask_downscaled: torch.Tensor,
    image: torch.Tensor,
    on_pred: bool = True,
):
    """L1 loss on known source pixels, plus on the downscaled prediction if on_pred=True.

    Parameters
    ----------
    pred : torch.Tensor
        prediction at the current scale
    pred_downscaled : torch.Tensor
        prediction downscaled to the shape of ``ref``
    ref : torch.Tensor
        reference the downscaled prediction is compared against
    mask : torch.Tensor
        mask with the same shape as ``pred``/``image``; entries < 1e-8 select
        the pixels where ``pred`` is compared with ``image``
    mask_downscaled : torch.Tensor
        mask with the same shape as ``pred_downscaled``/``ref``; entries
        >= 1e-8 select the pixels where ``pred_downscaled`` is compared with ``ref``
    image : torch.Tensor
        source image at the current scale
    on_pred : bool, optional
        if True, add the downscaled-prediction term, by default True

    Returns
    -------
    torch.Tensor
        scalar loss. A term whose selection is empty (e.g. a downscaled mask
        that is all zeros) contributes 0 instead of NaN.
    """
    loss = _masked_l1_mean(pred, image, mask < 1e-8)
    if on_pred:
        loss += _masked_l1_mean(pred_downscaled, ref, mask_downscaled >= 1e-8)
    return loss


def _infer(
    image: torch.Tensor,
    mask: torch.Tensor,
    forward_front: torch.nn.Module,
    forward_rears: List[torch.nn.Module],
    ref_lower_res: torch.Tensor,
    orig_shape: tuple,
    devices: list,
    scale_ind: int,
    n_iters: int = 15,
    lr: float = 0.002,
):
    """Performs inference with refinement at a given scale.

    The front part of the network is run once without gradients. Its two
    output feature tensors are then treated as free variables and optimized
    with Adam so that the output of the rear parts (a) matches ``image`` on
    the pixels outside the mask and (b) matches ``ref_lower_res`` after
    downscaling on the eroded downscaled mask region. When ``ref_lower_res``
    is None no optimization is done and the plain network output is used.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W)
    forward_front : torch.nn.Module
        the front part of the inpainting network; called on the masked image
        concatenated with the mask and expected to return two tensors
    forward_rears : List[torch.nn.Module]
        the rear parts of the inpainting network, applied in order. The
        output of part ``i`` is moved to ``devices[i + 1]`` before part
        ``i + 1`` is called; the output of the last part is the prediction.
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image. May be
        None, in which case a single forward pass is made without refinement.
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices. The optimized features live on
        ``devices[0]``; mask, image and the erosion kernel are moved to
        ``devices[-1]``.
    scale_ind : int
        the scale index (NOTE(review): not used anywhere in the body)
    n_iters : int, optional
        number of iterations of refinement, by default 15
        (NOTE(review): with n_iters < 1 the loop body never runs and ``pred``
        is unbound at the final blend - confirm callers always pass >= 1)
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image, detached and moved to the CPU
    """
    # Network input: the image with masked pixels multiplied out, with the
    # mask appended as an extra channel.
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    # Repeat the mask over 3 channels so it can index / blend 3-channel tensors.
    mask = mask.repeat(1, 3, 1, 1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        # The front part is never optimized, so no graph is needed for it.
        z: Tuple[torch.Tensor, torch.Tensor] = forward_front(masked_image)
        z1, z2 = z
    # Everything used for the loss and the final blend goes to the last device.
    mask = mask.to(devices[-1])
    # 15x15 elliptical structuring element used to erode the downscaled mask.
    ekernel = torch.from_numpy(
        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)).astype(bool)
    ).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])

    # Inference mode is explicitly switched off for the optimization region
    # (presumably because callers may invoke this function under
    # torch.inference_mode(), where autograd is unavailable - TODO confirm).
    with torch.inference_mode(False):
        # Detached copies with requires_grad=True: these are the variables
        # Adam optimizes.
        z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
        z1.requires_grad, z2.requires_grad = True, True

        optimizer = Adam([z1, z2], lr=lr)

        pbar = range(n_iters)
        for idi in pbar:
            optimizer.zero_grad()
            input_feat = (z1, z2)
            # Run the rear parts in sequence, hopping to the next device
            # between parts.
            for idd, forward_rear in enumerate(forward_rears):
                output_feat = forward_rear(input_feat)
                if idd < len(devices) - 1:
                    midz1, midz2 = output_feat
                    midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to(devices[idd + 1])
                    input_feat = (midz1, midz2)
                else:
                    pred = output_feat

            # No lower-resolution reference: keep the plain prediction.
            if ref_lower_res is None:
                break
            losses = {}
            # Multi-scale loss: downscale the un-padded prediction and mask so
            # they can be compared with the previous scale's result.
            pred_downscaled = _pyrdown(pred[:, :, : orig_shape[0], : orig_shape[1]])
            mask_downscaled = _pyrdown_mask(
                mask[:, :1, : orig_shape[0], : orig_shape[1]], blur_mask=False, round_up=False
            )
            mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
            mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
            losses["ms_l1"] = _l1_loss(
                pred, pred_downscaled, ref_lower_res, mask, mask_downscaled, image, on_pred=True
            )

            loss = sum(losses.values())

            # The last iteration skips the update (and the deletions) so that
            # the `pred` used below is the output for the current z1/z2.
            if idi < n_iters - 1:
                loss.backward()
                optimizer.step()
                # Drop references to this iteration's tensors before the next
                # forward pass.
                del pred_downscaled
                del loss
                del pred

        # Use the prediction inside the mask and the input image elsewhere.
        inpainted = mask * pred + (1 - mask) * image
        inpainted = inpainted.detach().cpu()
        return inpainted


def _get_image_mask_pyramid(batch: dict, min_side: int, max_scales: int, px_budget: int):
    """Build the image mask pyramid

    The image and mask are first cropped to ``batch["unpad_to_size"]`` and, if
    they exceed the pixel budget, resized (aspect ratio preserved up to integer
    truncation) with the resized mask's values above 1e-8 set to 1. Each
    further level is produced by `_pyrdown` / `_pyrdown_mask` from the previous one.

    Parameters
    ----------
    batch : dict
        batch with the keys "image", "mask" and "unpad_to_size"
    min_side : int
        minimum side length to limit the number of scales of the pyramid
    max_scales : int
        maximum number of scales allowed
    px_budget : int
        the product H*W cannot exceed this budget, because of resource constraints

    Returns
    -------
    tuple
        (list of images, list of masks), both ordered from the lowest
        resolution (index 0) to the highest
    """
    assert batch["image"].shape[0] == 1, "refiner works on only batches of size 1!"

    unpad_h, unpad_w = batch["unpad_to_size"]
    h = unpad_h[0].item()
    w = unpad_w[0].item()

    image = batch["image"][..., :h, :w]
    mask = batch["mask"][..., :h, :w]

    if h * w > px_budget:
        # Too many pixels: shrink both sides by the same factor to fit the budget.
        h_orig, w_orig = h, w
        ratio = np.sqrt(px_budget / float(h_orig * w_orig))
        h, w = int(h_orig * ratio), int(w_orig * ratio)
        print(f"Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...")
        image = resize(image, (h, w), interpolation="bilinear", align_corners=False)
        mask = resize(mask, (h, w), interpolation="bilinear", align_corners=False)
        mask[mask > 1e-8] = 1

    # Number of halvings that keep the shorter side around min_side, capped by max_scales.
    octaves = max(0, np.log2(min(h, w) / min_side))
    n_scales = min(1 + int(round(octaves)), max_scales)

    ls_images = [image]
    ls_masks = [mask]
    for _ in range(n_scales - 1):
        ls_images.append(_pyrdown(ls_images[-1]))
        ls_masks.append(_pyrdown_mask(ls_masks[-1]))

    # Coarsest level first.
    return list(reversed(ls_images)), list(reversed(ls_masks))


def refine_predict(
    batch: dict,
    inpainter: torch.nn.Module,
    gpu_ids: str = "0",
    modulo: int = 8,
    n_iters: int = 15,
    lr: float = 0.002,
    min_side: int = 512,
    max_scales: int = 5,
    px_budget: int = 800000,
):
    """Refines the inpainting of the network

    The network is split into a front part (all layers before the first
    residual block) and one rear part per GPU. The image/mask pyramid is then
    processed from the coarsest to the finest scale, each scale using the
    previous scale's result as the reference for refinement.

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : torch.nn.Module
        the inpainting neural network, in eval mode. It must support ``len``,
        integer indexing and slicing over its layers (e.g. ``nn.Sequential``);
        layers whose class is named "FFCResnetBlock" or "ResnetBlock" are the
        residual blocks that get distributed over the GPUs.
    gpu_ids : str
        comma-separated GPU ids of the machine to use, e.g. "0" or "0,1".
        Spaces and non-numeric entries (such as the empty entry after a
        trailing comma in "0,") are ignored.
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W), cropped back to the un-padded size
        of the finest pyramid level

    Raises
    ------
    ValueError
        if ``gpu_ids`` contains no numeric GPU id
    """
    assert not inpainter.training

    devices = [
        torch.device(f"cuda:{gpuid}")
        for gpuid in gpu_ids.replace(" ", "").split(",")
        if gpuid.isdigit()
    ]
    if not devices:
        # Previously this surfaced as a ZeroDivisionError further down.
        raise ValueError(
            f"gpu_ids must contain at least one numeric GPU id, got {gpu_ids!r}"
        )

    # Flag every layer as residual block or not; the front part is everything
    # before the first residual block.
    is_resblock = [
        isinstance_string(inpainter[idl], "FFCResnetBlock")
        or isinstance_string(inpainter[idl], "ResnetBlock")
        for idl in range(len(inpainter))
    ]
    n_resnet_blocks = sum(is_resblock)
    first_resblock_ind = is_resblock.index(True) if n_resnet_blocks else len(is_resblock)
    resblocks_per_gpu = n_resnet_blocks // len(devices)

    # split the model into front, and rear parts
    forward_front = inpainter[0:first_resblock_ind]
    forward_front.to(devices[0])
    forward_rears = []
    last_idd = len(devices) - 1
    for idd, device in enumerate(devices):
        start = first_resblock_ind + resblocks_per_gpu * idd
        if idd < last_idd:
            rear = inpainter[start : start + resblocks_per_gpu]
        else:
            # The last device also takes the remainder blocks and every layer
            # after the residual blocks.
            rear = inpainter[start:]
        rear.to(device)
        forward_rears.append(rear)

    ls_images, ls_masks = _get_image_mask_pyramid(batch, min_side, max_scales, px_budget)
    image_inpainted = None

    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        # Re-binarize in place after padding.
        # NOTE(review): if pad_tensor_to_modulo can return its input unchanged,
        # this writes into the pyramid/batch mask - confirm that is acceptable.
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = move_to_device(image, devices[0]), move_to_device(mask, devices[0])
        if image_inpainted is not None:
            # The previous scale's result is the reference for this scale.
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(
            image,
            mask,
            forward_front,
            forward_rears,
            image_inpainted,
            orig_shape,
            devices,
            ids,
            n_iters,
            lr,
        )
        # Remove the modulo padding again.
        image_inpainted = image_inpainted[:, :, : orig_shape[0], : orig_shape[1]]

    return image_inpainted
